{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Finetune BERT-Bahasa"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "\n",
    "This tutorial is available as an IPython notebook at [Malaya/finetune/bert](https://github.com/huseinzol05/Malaya/tree/master/finetune/bert).\n",
    "    \n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, I am going to show how to finetune pretrained BERT-Bahasa using TensorFlow Estimator.\n",
    "\n",
    "TF-Estimator is a module created by the TensorFlow team that suits long training runs well, because it saves checkpoints to the model directory and resumes from the latest one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip3 install bert-tensorflow==1.0.1 tensorflow==1.15"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download pretrained model\n",
    "\n",
    "Pretrained models are listed at https://github.com/huseinzol05/Malaya/tree/master/pretrained-model/bert#download. In this example, we are going to use the BASE size. Uncomment the lines below to download the pretrained model, the tokenizer vocabulary and the config."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BASE_config.json  bert-base-2020-10-08.tar.gz\r\n",
      "BERT.wordpiece\t  tf-estimator-text-classification.ipynb\r\n",
      "bert-base\r\n"
     ]
    }
   ],
   "source": [
    "# !wget https://f000.backblazeb2.com/file/malaya-model/bert-bahasa/bert-base-2020-10-08.tar.gz\n",
    "# !wget https://raw.githubusercontent.com/huseinzol05/Malaya/master/pretrained-model/bert/BERT.wordpiece\n",
    "# !wget https://raw.githubusercontent.com/huseinzol05/Malaya/master/pretrained-model/bert/config/BASE_config.json\n",
    "# !tar -zxf bert-base-2020-10-08.tar.gz\n",
    "!ls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model.ckpt-1000000.data-00000-of-00001\tmodel.ckpt-1000000.meta\r\n",
      "model.ckpt-1000000.index\r\n"
     ]
    }
   ],
   "source": [
    "!ls bert-base"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There is a helper module, [malaya/finetune/utils.py](https://github.com/huseinzol05/Malaya/blob/master/finetune/utils.py), to help us train the model on a single GPU or multiple GPUs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.insert(0, '../')\n",
    "import utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are going to train on a very small Bahasa news sentiment dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>text</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Negative</td>\n",
       "      <td>Lebih-lebih lagi dengan  kemudahan internet da...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Positive</td>\n",
       "      <td>boleh memberi teguran kepada parti tetapi perl...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Negative</td>\n",
       "      <td>Adalah membingungkan mengapa masyarakat Cina b...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Positive</td>\n",
       "      <td>Kami menurunkan defisit daripada 6.7 peratus p...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Negative</td>\n",
       "      <td>Ini masalahnya. Bukan rakyat, tetapi sistem</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      label                                               text\n",
       "0  Negative  Lebih-lebih lagi dengan  kemudahan internet da...\n",
       "1  Positive  boleh memberi teguran kepada parti tetapi perl...\n",
       "2  Negative  Adalah membingungkan mengapa masyarakat Cina b...\n",
       "3  Positive  Kami menurunkan defisit daripada 6.7 peratus p...\n",
       "4  Negative        Ini masalahnya. Bukan rakyat, tetapi sistem"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df = pd.read_csv('../sentiment-data-v2.csv')\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Negative', 'Positive']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels = df['label'].values.tolist()\n",
    "texts = df['text'].values.tolist()\n",
    "unique_labels = sorted(list(set(labels)))\n",
    "unique_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/ubuntu/.local/lib/python3.6/site-packages/bert/optimization.py:87: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import bert\n",
    "from bert import run_classifier\n",
    "from bert import optimization\n",
    "from bert import tokenization\n",
    "from bert import modeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/ubuntu/.local/lib/python3.6/site-packages/bert/tokenization.py:125: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "['Husein', 'Comel', 'tersangat', 'sangatlah']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer = tokenization.FullTokenizer(vocab_file = 'BERT.wordpiece', do_lower_case = False)\n",
    "tokens = tokenizer.tokenize('Husein Comel tersangat sangatlah')\n",
    "tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[31560, 17094, 26759, 30559]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.convert_tokens_to_ids(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def token_to_ids(text, maxlen = 512):\n",
    "    tokens_a = tokenizer.tokenize(text)\n",
    "    # reserve 2 positions for the [CLS] and [SEP] special tokens\n",
    "    if len(tokens_a) > maxlen - 2:\n",
    "        tokens_a = tokens_a[:(maxlen - 2)]\n",
    "    tokens = ['[CLS]'] + tokens_a + ['[SEP]']\n",
    "    # single sentence input, so every token belongs to segment 0\n",
    "    segment_id = [0] * len(tokens)\n",
    "    # 1 marks a real token; padding added later during batching is 0\n",
    "    input_mask = [1] * len(tokens)\n",
    "    input_id = tokenizer.convert_tokens_to_ids(tokens)\n",
    "    return {'tokens': tokens, 'input_id': input_id,\n",
    "            'input_mask': input_mask, 'segment_id': segment_id}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. `tokens`, the tokenized words.\n",
    "2. `input_id`, the integer representation of the tokens, that is, each token's index in the wordpiece vocabulary.\n",
    "3. `input_mask`, the attention mask. During training, short sequences will be padded with `0`, and we do not want the model to learn padded values as part of the context, so padded positions get `0` in the mask and real tokens get `1`.\n",
    "4. `segment_id`, used for text pair classification. In this case we only have a single text, so we can simply put `0`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'tokens': ['[CLS]',\n",
       "  'Lebih',\n",
       "  '-',\n",
       "  'lebih',\n",
       "  'lagi',\n",
       "  'dengan',\n",
       "  'kemudahan',\n",
       "  'internet',\n",
       "  'dan',\n",
       "  'laman',\n",
       "  'sosial',\n",
       "  ',',\n",
       "  'taktik',\n",
       "  'ini',\n",
       "  'semakin',\n",
       "  'mudah',\n",
       "  'dikembangkan',\n",
       "  '.',\n",
       "  '[SEP]'],\n",
       " 'input_id': [2,\n",
       "  4015,\n",
       "  17,\n",
       "  2009,\n",
       "  2088,\n",
       "  1822,\n",
       "  5714,\n",
       "  6332,\n",
       "  1766,\n",
       "  3062,\n",
       "  3558,\n",
       "  16,\n",
       "  20153,\n",
       "  1828,\n",
       "  3718,\n",
       "  2766,\n",
       "  20018,\n",
       "  18,\n",
       "  3],\n",
       " 'input_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       " 'segment_id': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "token_to_ids(texts[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TF-Estimator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TF-Estimator requires 2 parts:\n",
    "\n",
    "1. Input pipeline, https://www.tensorflow.org/api_docs/python/tf/data/Dataset\n",
    "2. Model definition, https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
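    "def generate():\n",
    "    # loop forever so the Estimator can keep pulling batches for as many steps as it needs\n",
    "    while True:\n",
    "        for i in range(len(texts)):\n",
    "            # skip texts of 5 characters or fewer\n",
    "            if len(texts[i]) > 5:\n",
    "                d = token_to_ids(texts[i])\n",
    "                # label is the index of the class name in the sorted unique_labels\n",
    "                d['label'] = [unique_labels.index(labels[i])]\n",
    "                d.pop('tokens', None)\n",
    "                yield d"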
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_id': [2,\n",
       "  4015,\n",
       "  17,\n",
       "  2009,\n",
       "  2088,\n",
       "  1822,\n",
       "  5714,\n",
       "  6332,\n",
       "  1766,\n",
       "  3062,\n",
       "  3558,\n",
       "  16,\n",
       "  20153,\n",
       "  1828,\n",
       "  3718,\n",
       "  2766,\n",
       "  20018,\n",
       "  18,\n",
       "  3],\n",
       " 'input_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       " 'segment_id': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       " 'label': [0]}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "g = generate()\n",
    "next(g)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
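    "It must be a function that returns a function, because the Estimator calls the inner function itself to build the dataset inside its own graph.\n",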
    "\n",
    "```python\n",
    "def get_dataset(batch_size = 32, shuffle_size = 32):\n",
    "    def get():\n",
    "        return dataset\n",
    "    return get\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset(batch_size = 32, shuffle_size = 32):\n",
    "    def get():\n",
    "        dataset = tf.data.Dataset.from_generator(\n",
    "            generate,\n",
    "            {'input_id': tf.int32, 'input_mask': tf.int32, 'segment_id': tf.int32, 'label': tf.int32},\n",
    "            output_shapes = {\n",
    "                'input_id': tf.TensorShape([None]),\n",
    "                'input_mask': tf.TensorShape([None]),\n",
    "                'segment_id': tf.TensorShape([None]),\n",
    "                'label': tf.TensorShape([None])\n",
    "            },\n",
    "        )\n",
    "        dataset = dataset.shuffle(shuffle_size)\n",
    "        dataset = dataset.padded_batch(\n",
    "            batch_size,\n",
    "            padded_shapes = {\n",
    "                'input_id': tf.TensorShape([None]),\n",
    "                'input_mask': tf.TensorShape([None]),\n",
    "                'segment_id': tf.TensorShape([None]),\n",
    "                'label': tf.TensorShape([None])\n",
    "            },\n",
    "            padding_values = {\n",
    "                'input_id': tf.constant(0, dtype = tf.int32),\n",
    "                'input_mask': tf.constant(0, dtype = tf.int32),\n",
    "                'segment_id': tf.constant(0, dtype = tf.int32),\n",
    "                'label': tf.constant(0, dtype = tf.int32),\n",
    "            },\n",
    "        )\n",
    "        return dataset\n",
    "    return get"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Test data pipeline using tf.InteractiveSession"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-15-2f00f4f10c26>:4: DatasetV1.make_one_shot_iterator (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "iterator = get_dataset()()\n",
    "iterator = iterator.make_one_shot_iterator().get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_id': <tf.Tensor 'IteratorGetNext:0' shape=(?, ?) dtype=int32>,\n",
       " 'input_mask': <tf.Tensor 'IteratorGetNext:1' shape=(?, ?) dtype=int32>,\n",
       " 'segment_id': <tf.Tensor 'IteratorGetNext:3' shape=(?, ?) dtype=int32>,\n",
       " 'label': <tf.Tensor 'IteratorGetNext:2' shape=(?, ?) dtype=int32>}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_id': array([[    2,  2009, 12237, ...,     0,     0,     0],\n",
       "        [    2,  3543,  7554, ...,     0,     0,     0],\n",
       "        [    2,  2007,  8065, ...,     0,     0,     0],\n",
       "        ...,\n",
       "        [    2,  3566,  3841, ...,     0,     0,     0],\n",
       "        [    2,  3217,  1011, ...,     0,     0,     0],\n",
       "        [    2,  6009,  4177, ...,     0,     0,     0]], dtype=int32),\n",
       " 'input_mask': array([[1, 1, 1, ..., 0, 0, 0],\n",
       "        [1, 1, 1, ..., 0, 0, 0],\n",
       "        [1, 1, 1, ..., 0, 0, 0],\n",
       "        ...,\n",
       "        [1, 1, 1, ..., 0, 0, 0],\n",
       "        [1, 1, 1, ..., 0, 0, 0],\n",
       "        [1, 1, 1, ..., 0, 0, 0]], dtype=int32),\n",
       " 'segment_id': array([[0, 0, 0, ..., 0, 0, 0],\n",
       "        [0, 0, 0, ..., 0, 0, 0],\n",
       "        [0, 0, 0, ..., 0, 0, 0],\n",
       "        ...,\n",
       "        [0, 0, 0, ..., 0, 0, 0],\n",
       "        [0, 0, 0, ..., 0, 0, 0],\n",
       "        [0, 0, 0, ..., 0, 0, 0]], dtype=int32),\n",
       " 'label': array([[0],\n",
       "        [1],\n",
       "        [1],\n",
       "        [0],\n",
       "        [0],\n",
       "        [0],\n",
       "        [1],\n",
       "        [1],\n",
       "        [0],\n",
       "        [1],\n",
       "        [0],\n",
       "        [0],\n",
       "        [1],\n",
       "        [1],\n",
       "        [1],\n",
       "        [1],\n",
       "        [1],\n",
       "        [1],\n",
       "        [0],\n",
       "        [1],\n",
       "        [1],\n",
       "        [0],\n",
       "        [1],\n",
       "        [1],\n",
       "        [0],\n",
       "        [1],\n",
       "        [1],\n",
       "        [0],\n",
       "        [1],\n",
       "        [0],\n",
       "        [0],\n",
       "        [1]], dtype=int32)}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sess.run(iterator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model definition"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It must be a function that accepts 4 parameters.\n",
    "\n",
    "```python\n",
    "def model_fn(features, labels, mode, params):\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'vocab_size': 32000,\n",
       " 'hidden_size': 768,\n",
       " 'num_hidden_layers': 12,\n",
       " 'num_attention_heads': 12,\n",
       " 'hidden_act': 'gelu',\n",
       " 'intermediate_size': 3072,\n",
       " 'hidden_dropout_prob': 0.1,\n",
       " 'attention_probs_dropout_prob': 0.1,\n",
       " 'max_position_embeddings': 512,\n",
       " 'type_vocab_size': 2,\n",
       " 'initializer_range': 0.02,\n",
       " 'directionality': 'bidi',\n",
       " 'pooler_fc_size': 768,\n",
       " 'pooler_num_attention_heads': 12,\n",
       " 'pooler_num_fc_layers': 3,\n",
       " 'pooler_size_per_head': 128,\n",
       " 'pooler_type': 'first_token_transform'}"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bert_config = modeling.BertConfig.from_json_file('BASE_config.json')\n",
    "bert_config.__dict__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
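    "# NOTE: `epoch` is passed to `optimization.create_optimizer` as the total number of training steps.\n",
    "epoch = 10\n",
    "warmup_proportion = 0.1\n",
    "num_warmup_steps = int(epoch * warmup_proportion)\n",
    "learning_rate = 2e-5\n",
    "init_checkpoint = 'bert-base/model.ckpt-1000000'"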
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_fn(features, labels, mode, params):\n",
    "    Y = tf.cast(features['label'][:, 0], tf.int32)\n",
    "    \n",
    "    model = modeling.BertModel(\n",
    "        config = bert_config,\n",
    "        is_training = (mode == tf.estimator.ModeKeys.TRAIN),  # disable dropout outside training\n",
    "        input_ids = features['input_id'],\n",
    "        input_mask = features['input_mask'],\n",
    "        token_type_ids = features['segment_id'],\n",
    "        use_one_hot_embeddings = False,\n",
    "    )\n",
    "    output_layer = model.get_pooled_output()\n",
    "    logits = tf.layers.dense(output_layer, len(unique_labels))\n",
    "    loss = tf.reduce_mean(\n",
    "        tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "            logits = logits, labels = Y\n",
    "        )\n",
    "    )\n",
    "\n",
    "    tf.identity(loss, 'train_loss')\n",
    "\n",
    "    accuracy = tf.metrics.accuracy(\n",
    "        labels = Y, predictions = tf.argmax(logits, axis = 1)\n",
    "    )\n",
    "    tf.identity(accuracy[1], name = 'train_accuracy')\n",
    "    \n",
    "    variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)\n",
    "    \n",
    "    assignment_map, initialized_variable_names = utils.get_assignment_map_from_checkpoint(\n",
    "        variables, init_checkpoint\n",
    "    )\n",
    "\n",
    "    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)\n",
    "    \n",
    "    if mode == tf.estimator.ModeKeys.TRAIN:\n",
    "        train_op = optimization.create_optimizer(loss, learning_rate, epoch, num_warmup_steps, False)\n",
    "        estimator_spec = tf.estimator.EstimatorSpec(\n",
    "            mode = mode, loss = loss, train_op = train_op\n",
    "        )\n",
    "        \n",
    "    elif mode == tf.estimator.ModeKeys.EVAL:\n",
    "        estimator_spec = tf.estimator.EstimatorSpec(\n",
    "            mode = tf.estimator.ModeKeys.EVAL,\n",
    "            loss = loss,\n",
    "            eval_metric_ops = {'accuracy': accuracy},\n",
    "        )\n",
    "\n",
    "    return estimator_spec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initiate training session"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = get_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using config: {'_model_dir': 'finetuned-bert-base', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 10, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true\n",
      "graph_options {\n",
      "  rewrite_options {\n",
      "    meta_optimizer_iterations: ONE\n",
      "  }\n",
      "}\n",
      ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 1, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fca3eb97080>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:**** Trainable Variables ****\n",
      "INFO:tensorflow:  name = bert/embeddings/word_embeddings:0, shape = (32000, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/embeddings/token_type_embeddings:0, shape = (2, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/embeddings/position_embeddings:0, shape = (512, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/embeddings/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/embeddings/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_0/attention/self/query/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_0/attention/self/query/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_0/attention/self/key/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_0/attention/self/key/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_0/attention/self/value/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_0/attention/self/value/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_0/attention/output/dense/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_0/attention/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_0/attention/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_0/attention/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_0/intermediate/dense/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_0/intermediate/dense/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_0/output/dense/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_0/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_0/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_0/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_1/attention/self/query/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_1/attention/self/query/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_1/attention/self/key/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_1/attention/self/key/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_1/attention/self/value/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_1/attention/self/value/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_1/attention/output/dense/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_1/attention/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_1/attention/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_1/attention/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_1/intermediate/dense/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_1/intermediate/dense/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_1/output/dense/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_1/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_1/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_1/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_2/attention/self/query/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_2/attention/self/query/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_2/attention/self/key/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_2/attention/self/key/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_2/attention/self/value/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_2/attention/self/value/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_2/attention/output/dense/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_2/attention/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_2/attention/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_2/attention/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_2/intermediate/dense/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_2/intermediate/dense/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_2/output/dense/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_2/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_2/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_2/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_3/attention/self/query/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_3/attention/self/query/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_3/attention/self/key/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_3/attention/self/key/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_3/attention/self/value/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_3/attention/self/value/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_3/attention/output/dense/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_3/attention/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_3/attention/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_3/attention/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_3/intermediate/dense/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_3/intermediate/dense/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_3/output/dense/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_3/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_3/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_3/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_4/attention/self/query/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_4/attention/self/query/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_4/attention/self/key/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_4/attention/self/key/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_4/attention/self/value/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_4/attention/self/value/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_4/attention/output/dense/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_4/attention/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_4/attention/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_4/attention/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_4/intermediate/dense/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_4/intermediate/dense/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_4/output/dense/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_4/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_4/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_4/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_5/attention/self/query/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_5/attention/self/query/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_5/attention/self/key/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_5/attention/self/key/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_5/attention/self/value/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_5/attention/self/value/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_5/attention/output/dense/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_5/attention/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_5/attention/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_5/attention/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_5/intermediate/dense/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_5/intermediate/dense/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_5/output/dense/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_5/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_5/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_5/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_6/attention/self/query/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_6/attention/self/query/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_6/attention/self/key/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_6/attention/self/key/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_6/attention/self/value/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_6/attention/self/value/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_6/attention/output/dense/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_6/attention/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_6/attention/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_6/attention/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_6/intermediate/dense/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_6/intermediate/dense/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_6/output/dense/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_6/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_6/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_6/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_7/attention/self/query/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_7/attention/self/query/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_7/attention/self/key/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_7/attention/self/key/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_7/attention/self/value/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_7/attention/self/value/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_7/attention/output/dense/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_7/attention/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_7/attention/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_7/attention/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_7/intermediate/dense/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_7/intermediate/dense/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_7/output/dense/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_7/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_7/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_7/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_8/attention/self/query/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_8/attention/self/query/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_8/attention/self/key/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_8/attention/self/key/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_8/attention/self/value/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_8/attention/self/value/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_8/attention/output/dense/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_8/attention/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_8/attention/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_8/attention/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_8/intermediate/dense/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_8/intermediate/dense/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_8/output/dense/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_8/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_8/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_8/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_9/attention/self/query/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_9/attention/self/query/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_9/attention/self/key/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_9/attention/self/key/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_9/attention/self/value/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_9/attention/self/value/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_9/attention/output/dense/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_9/attention/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_9/attention/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_9/attention/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_9/intermediate/dense/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_9/intermediate/dense/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_9/output/dense/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_9/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_9/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_9/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_10/attention/self/query/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_10/attention/self/query/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_10/attention/self/key/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_10/attention/self/key/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_10/attention/self/value/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_10/attention/self/value/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_10/attention/output/dense/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_10/attention/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_10/attention/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_10/attention/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_10/intermediate/dense/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_10/intermediate/dense/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_10/output/dense/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_10/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_10/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_10/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_11/attention/self/query/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_11/attention/self/query/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_11/attention/self/key/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_11/attention/self/key/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_11/attention/self/value/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_11/attention/self/value/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_11/attention/output/dense/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_11/attention/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_11/attention/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_11/attention/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_11/intermediate/dense/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_11/intermediate/dense/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_11/output/dense/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_11/output/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_11/output/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/encoder/layer_11/output/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/pooler/dense/kernel:0, shape = (768, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = bert/pooler/dense/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = dense/kernel:0, shape = (768, 2)\n",
      "INFO:tensorflow:  name = dense/bias:0, shape = (2,)\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 0 into finetuned-bert-base/model.ckpt.\n",
      "INFO:tensorflow:train_accuracy = 0.34375, train_loss = 0.7432811\n",
      "INFO:tensorflow:loss = 0.7432811, step = 1\n",
      "INFO:tensorflow:global_step/sec: 0.0707289\n",
      "INFO:tensorflow:train_accuracy = 0.4375, train_loss = 1.6084869 (14.139 sec)\n",
      "INFO:tensorflow:loss = 1.6084869, step = 2 (14.138 sec)\n",
      "INFO:tensorflow:global_step/sec: 0.17299\n",
      "INFO:tensorflow:train_accuracy = 0.5416667, train_loss = 0.71116924 (5.781 sec)\n",
      "INFO:tensorflow:loss = 0.71116924, step = 3 (5.781 sec)\n",
      "INFO:tensorflow:global_step/sec: 0.181334\n",
      "WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 3 vs previous value: 3. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.\n",
      "INFO:tensorflow:train_accuracy = 0.546875, train_loss = 0.6678002 (5.516 sec)\n",
      "INFO:tensorflow:loss = 0.6678002, step = 4 (5.515 sec)\n",
      "INFO:tensorflow:global_step/sec: 0.0801607\n",
      "INFO:tensorflow:train_accuracy = 0.5125, train_loss = 1.4128941 (12.474 sec)\n",
      "INFO:tensorflow:loss = 1.4128941, step = 5 (12.475 sec)\n",
      "INFO:tensorflow:global_step/sec: 0.185281\n",
      "WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 5 vs previous value: 5. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.\n",
      "INFO:tensorflow:train_accuracy = 0.49479166, train_loss = 1.22251 (5.398 sec)\n",
      "INFO:tensorflow:loss = 1.22251, step = 6 (5.398 sec)\n",
      "INFO:tensorflow:global_step/sec: 0.14771\n",
      "INFO:tensorflow:train_accuracy = 0.4955357, train_loss = 0.75944936 (6.769 sec)\n",
      "INFO:tensorflow:loss = 0.75944936, step = 7 (6.769 sec)\n",
      "INFO:tensorflow:global_step/sec: 0.129142\n",
      "WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 7 vs previous value: 7. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.\n",
      "INFO:tensorflow:train_accuracy = 0.52734375, train_loss = 0.4374127 (7.745 sec)\n",
      "INFO:tensorflow:loss = 0.4374127, step = 8 (7.745 sec)\n",
      "INFO:tensorflow:global_step/sec: 0.185809\n",
      "INFO:tensorflow:train_accuracy = 0.5590278, train_loss = 0.47080472 (5.380 sec)\n",
      "INFO:tensorflow:loss = 0.47080472, step = 9 (5.380 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 10 into finetuned-bert-base/model.ckpt.\n",
      "INFO:tensorflow:global_step/sec: 0.122564\n",
      "INFO:tensorflow:train_accuracy = 0.5625, train_loss = 0.6999684 (8.159 sec)\n",
      "INFO:tensorflow:loss = 0.6999684, step = 10 (8.160 sec)\n",
      "INFO:tensorflow:Loss for final step: 0.6999684.\n"
     ]
    }
   ],
   "source": [
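    "# Log the tensors named 'train_accuracy' and 'train_loss' at every training step.\n",
    "train_hooks = [\n",
    "    tf.train.LoggingTensorHook(\n",
    "        ['train_accuracy', 'train_loss'], every_n_iter = 1\n",
    "    )\n",
    "]\n",
    "\n",
    "# Note: `epoch` is passed as `max_steps`, so it counts training steps,\n",
    "# not full passes over the dataset. Checkpoints are written to `model_dir`.\n",
    "utils.run_training(\n",
    "    train_fn = train_dataset,\n",
    "    model_fn = model_fn,\n",
    "    model_dir = 'finetuned-bert-base',\n",
    "    num_gpus = 1,\n",
    "    log_step = 1,\n",
    "    save_checkpoint_step = epoch,\n",
    "    max_steps = epoch,\n",
    "    train_hooks = train_hooks,\n",
    ")"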
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
